{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}

#pragma once

#include <Eigen/Core>

#include <sym/ops/lie_group_ops.h>

namespace sym {

/**
 * C++ LieGroupOps implementation for matrices.
 *
 * A matrix is handled as a flat vector of its coefficients: the tangent vector of a matrix is
 * its coefficient buffer (as returned by data()) viewed as a column vector, and the group
 * operation is coefficient-wise addition. None of the operations need epsilon; it is accepted
 * only to keep the signatures uniform.
 */
template <typename ScalarType, int Rows, int Cols>
struct LieGroupOps<Eigen::Matrix<ScalarType, Rows, Cols>>
    : public internal::LieGroupOpsBase<Eigen::Matrix<ScalarType, Rows, Cols>, ScalarType> {
  using Scalar = ScalarType;
  using T = Eigen::Matrix<Scalar, Rows, Cols>;
  static_assert(std::is_floating_point<ScalarType>::value,
                "LieGroupOps for Eigen matrices requires a floating point scalar type");

  /**
   * Compile-time tangent dimension: Rows * Cols, or Eigen::Dynamic if either extent is dynamic.
   */
  static constexpr int32_t TangentDim() {
    return (Rows == Eigen::Dynamic || Cols == Eigen::Dynamic) ? Eigen::Dynamic : Rows * Cols;
  }

  using TangentVec = Eigen::Matrix<Scalar, TangentDim(), 1>;

  /**
   * Build a matrix whose coefficient buffer is a copy of vec.
   *
   * Fixed extents are used as is. If exactly one extent is dynamic and the other one is a
   * positive compile-time constant, the dynamic extent is recovered from vec.size(). If both
   * extents are dynamic the shape cannot be recovered from a flat vector; the extents are then
   * handed to Eigen unchanged (as Eigen::Dynamic), which Eigen rejects through its own size
   * checks.
   */
  static T FromTangent(const TangentVec& vec, const Scalar epsilon) {
    (void)epsilon;  // unused

    Eigen::Index rows = Rows;
    Eigen::Index cols = Cols;
    if (Rows == Eigen::Dynamic && Cols > 0) {
      rows = vec.size() / Cols;
    } else if (Cols == Eigen::Dynamic && Rows > 0) {
      cols = vec.size() / Rows;
    }

    return Eigen::Map<const T>(vec.data(), rows, cols);
  }

  /**
   * Return the coefficients of a as a column vector of length a.size().
   */
  static TangentVec ToTangent(const T& a, const Scalar epsilon) {
    (void)epsilon;  // unused

    return Eigen::Map<const TangentVec>(a.data(), a.size());
  }

  /**
   * Return a plus vec, where vec is viewed as a matrix with the same shape as a.
   */
  static T Retract(const T& a, const TangentVec& vec, const Scalar epsilon) {
    (void)epsilon;  // unused

    return a + Eigen::Map<const T>(vec.data(), a.rows(), a.cols());
  }

  /**
   * Return b minus a, with both operands viewed as column vectors of their coefficients.
   */
  static TangentVec LocalCoordinates(const T& a, const T& b, const Scalar epsilon) {
    (void)epsilon;  // unused

    return (Eigen::Map<const TangentVec>(b.data(), b.size()) -
            Eigen::Map<const TangentVec>(a.data(), a.size()));
  }

  /**
   * Move from a towards b by the fraction alpha of LocalCoordinates(a, b): alpha == 0 yields a.
   */
  static T Interpolate(const T& a, const T& b, const Scalar alpha, const Scalar epsilon) {
    return Retract(a, alpha * LocalCoordinates(a, b, epsilon), epsilon);
  }
};

}  // namespace sym

// Explicit instantiation declarations ("extern template"): they tell every translation unit
// that includes this header not to implicitly instantiate these specializations.
// NOTE(review): the matching explicit instantiation definitions are presumably emitted into a
// companion source file by the same generator - confirm.
//
// For each scalar type in scalar_types this declares:
//   - column vectors with 1 through 9 rows
//   - square matrices from 2x2 through 9x9
{% for scalar in scalar_types %}
{% for i in range(1, 10) %}
extern template struct sym::LieGroupOps<Eigen::Matrix<{{ scalar }}, {{ i }}, 1>>;
{% endfor %}
{% for i in range(2, 10) %}
extern template struct sym::LieGroupOps<Eigen::Matrix<{{ scalar }}, {{ i }}, {{ i }}>>;
{% endfor %}
{% endfor %}
